import numpy as np
from collections import defaultdict
from typing import List, Dict, Any, Optional


class HPOEvaluationMetrics:
    """Derive summary metrics from the trial log of a hyperparameter search.

    Each entry of ``trial_history`` is a dict that must contain ``"score"``
    and (for the config-based metrics) ``"config"``; ``"elapsed_time"`` is
    optional.  Entries of ``warmup_history`` are only read for
    ``"elapsed_time"``.

    ``mode`` selects the optimisation direction: ``"max"`` means higher
    scores are better; any other value is treated as minimisation.
    """

    def __init__(
        self,
        trial_history: List[Dict[str, Any]],
        warmup_history: Optional[List[Dict[str, Any]]] = None,
        pipeline_elapsed_time: Optional[float] = None,
        test_score: Optional[float] = None,
        best_known_optimum: Optional[float] = None,
        mode: str = "max",
    ):
        """Store the raw inputs and pre-compute score / duration arrays.

        Args:
            trial_history: Per-trial records of the search; ``None`` or an
                empty list is accepted and yields empty metrics.
            warmup_history: Optional per-trial records of a warmup phase.
            pipeline_elapsed_time: Optional wall-clock time of the whole
                pipeline, reported as-is by ``pipeline_wall_time``.
            test_score: Optional score passed through to ``summary``.
            best_known_optimum: Optional reference optimum used by the
                regret metrics.
            mode: ``"max"`` to maximise; anything else minimises.

        Raises:
            KeyError: If a trial record has no ``"score"`` key.
        """
        self.trial_history = trial_history or []
        self.warmup_history = warmup_history or []
        self.pipeline_elapsed_time = pipeline_elapsed_time
        self.test_score = test_score
        self.best_known_optimum = best_known_optimum
        self.mode = mode

        self.val_scores = np.array([t["score"] for t in self.trial_history])
        self.trial_times = self._elapsed_times(self.trial_history)
        self.warmup_times = self._elapsed_times(self.warmup_history)

    @staticmethod
    def _elapsed_times(history):
        """Return the ``elapsed_time`` of each record as a float array.

        A missing key and an explicit ``None`` both count as 0 so that the
        duration sums never hit a non-numeric entry.
        """
        return np.array([t.get("elapsed_time") or 0 for t in history], dtype=float)

    def _best_index(self, scores):
        """Return the position of the best entry of ``scores`` for ``self.mode``."""
        return np.argmax(scores) if self.mode == "max" else np.argmin(scores)

    def best_score(self):
        """Return the best validation score, or ``None`` when there are no trials."""
        if len(self.val_scores) == 0:
            return None
        return np.max(self.val_scores) if self.mode == "max" else np.min(self.val_scores)

    def best_config(self):
        """Return the ``"config"`` of the best-scoring trial, or ``None`` if no trials.

        On ties the earliest trial wins (first-occurrence semantics of
        ``argmax`` / ``argmin``).
        """
        if len(self.trial_history) == 0:
            return None
        return self.trial_history[self._best_index(self.val_scores)]["config"]

    def trial_count(self) -> int:
        """Return the number of search trials."""
        return len(self.trial_history)

    def warmup_trial_count(self) -> int:
        """Return the number of warmup trials."""
        return len(self.warmup_history)

    def warmup_duration(self) -> float:
        """Return the summed ``elapsed_time`` of the warmup trials (0.0 if none)."""
        return float(np.sum(self.warmup_times)) if len(self.warmup_times) > 0 else 0.0

    def hpo_trial_duration(self) -> float:
        """Return the summed ``elapsed_time`` of the search trials (0.0 if none)."""
        return float(np.sum(self.trial_times)) if len(self.trial_times) > 0 else 0.0

    def total_trial_duration(self) -> float:
        """Return warmup duration plus search-trial duration."""
        return self.warmup_duration() + self.hpo_trial_duration()

    def pipeline_wall_time(self) -> Optional[float]:
        """Return ``pipeline_elapsed_time`` as a float, or ``None`` if it was not given."""
        return float(self.pipeline_elapsed_time) if self.pipeline_elapsed_time is not None else None

    def calculate_regret_curve(self):
        """Return the per-trial regret of the running best score.

        Entry ``i`` is the gap between ``best_known_optimum`` and the best
        score among trials ``0..i`` (sign chosen so that a run that has not
        beaten the optimum has non-negative regret).

        Returns:
            A list with one value per trial, or ``None`` when there are no
            trials or no ``best_known_optimum`` was supplied.
        """
        if len(self.val_scores) == 0 or self.best_known_optimum is None:
            return None
        if self.mode == "max":
            best_so_far = np.maximum.accumulate(self.val_scores)
            return [self.best_known_optimum - b for b in best_so_far]
        best_so_far = np.minimum.accumulate(self.val_scores)
        return [b - self.best_known_optimum for b in best_so_far]

    def first_half_quality(self) -> tuple:
        """Return the best score in the first half of the trials and the time to reach it.

        The first half is the first ``int(n * 0.5)`` trials, so a single
        trial has an empty first half.

        Returns:
            ``(best_score, best_time)`` where ``best_time`` is the summed
            ``elapsed_time`` of all trials up to and including the best
            one, or ``(None, None)`` when the first half is empty.
        """
        half = int(len(self.val_scores) * 0.5)
        if half == 0:
            return None, None
        first_half = self.val_scores[:half]
        best_idx = self._best_index(first_half)
        best_time = float(np.sum(self.trial_times[: best_idx + 1]))
        return first_half[best_idx], best_time

    def normalized_regret(self):
        """Return the final regret divided by ``|best_known_optimum|``.

        When the optimum's magnitude is at most ``1e-8`` the regret is
        returned unscaled (denominator 1.0) to avoid dividing by ~0.
        Returns ``None`` without trials or without a known optimum.
        """
        if self.best_known_optimum is None or len(self.val_scores) == 0:
            return None
        best_found = self.best_score()
        if self.mode == "max":
            diff = self.best_known_optimum - best_found
        else:
            diff = best_found - self.best_known_optimum
        magnitude = abs(self.best_known_optimum)
        return diff / (magnitude if magnitude > 1e-8 else 1.0)

    def hyperparameter_coverage(self) -> Dict[str, int]:
        """Return, per hyperparameter name, how many distinct values were tried.

        Raises:
            KeyError: If a trial record has no ``"config"`` key.
        """
        coverage = defaultdict(set)
        for trial in self.trial_history:
            for name, value in trial["config"].items():
                try:
                    coverage[name].add(value)
                except TypeError:
                    # Unhashable value (list, dict, ...): track it through a
                    # (type name, repr) surrogate so it still counts once per
                    # distinct value.  NOTE: dicts with equal items but a
                    # different insertion order have different reprs and are
                    # therefore counted separately.
                    coverage[name].add((type(value).__name__, repr(value)))
        return {name: len(values) for name, values in coverage.items()}

    def summary(self) -> Dict[str, Any]:
        """Collect all metrics into one dict keyed by human-readable labels.

        A missing pipeline wall time is reported as 0; durations are
        rounded to two decimals.
        """
        fhq_score, fhq_time = self.first_half_quality()
        return {
            "Best score in HPO": self.best_score(),
            "Best config": self.best_config(),
            "Test score": self.test_score,
            "Pipeline wall time (s)": round(self.pipeline_wall_time() or 0, 2),
            "Warmup duration (s)": round(self.warmup_duration(), 2),
            "Trial duration (s)": round(self.hpo_trial_duration(), 2),
            "Total trial duration (s)": round(self.total_trial_duration(), 2),
            "Warmup trial count": self.warmup_trial_count(),
            "HPO trial count": self.trial_count(),
            "First half best score": fhq_score,
            "First half best time": fhq_time,
            "Normalized Regret": self.normalized_regret(),
            "Hyperparameter coverage": self.hyperparameter_coverage(),
        }
